# -*- coding: utf-8 -*-
"""
Created on Wed May 15 17:03:42 2019

@author: xiang_yaobing

1. Visualize the layers of the model
2. Produce prediction heatmaps
"""
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution

disable_eager_execution()
import numpy as np
from PIL import Image
from tensorflow.keras.models import Model,load_model
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from tensorflow.keras import utils
#%%
#import tensorflow as tf
#
#img_path = '..\\testset\ClusterKDEpicture_multiKde\IC1311.r0.2.fits_loc.jpg'
#im = np.array(Image.open(img_path))#读图像
#plt.imshow(im, cmap=plt.cm.binary)
#plt.show()
#img = im.reshape(1,60,60,3).astype('float32')
#
#
#last_layer_name = 'conv2d_5'
#model_path = '..\\model\\s2_best_weights_v1.0.h5'
#model = load_model(model_path)
#print(model.summary())
##model = tf.keras.applications.resnet50.ResNet50()
#img_tensor = img
#
#conv_layer = model.get_layer('conv2d_5')
#
#heatmap_model = tf.keras.models.Model(
#    [model.inputs], [model.get_layer('conv2d_5').output, model.output]
#)
#
#with tf.GradientTape() as tape:
#    conv_output, predictions = heatmap_model(img_tensor)
#    loss = predictions[:, np.argmax(predictions[0])]
#
#grads = tape.gradient(loss, conv_output)
#
#plt.rcParams['figure.figsize'] = (5, 5)
#plt.imshow(np.reshape(grads[:,:,:,1],(15,15)), cmap=plt.cm.binary)
#plt.show()
#%%
def conv_output(model, layer_name, img):
    """Get the output of a named layer for the given input.

    Args:
           model: keras model.
           layer_name: name of layer in the model.
           img: processed input image batch; passed unchanged to ``predict``.

    Returns:
           intermediate_output: the first element of the layer's predicted
           output, i.e. the feature map for the first sample in ``img``.

    Raises:
           ValueError: if ``model.get_layer`` reports that no layer is
           called ``layer_name``.
    """
    try:
        target_layer = model.get_layer(layer_name)
    except ValueError as err:
        # Only the "no such layer" failure is translated; any other error
        # propagates unchanged so it is not mislabelled as a bad name.
        raise ValueError('Not layer named {}!'.format(layer_name)) from err

    # Sub-model that maps the original input to the requested layer's output.
    intermediate_layer_model = Model(inputs=model.input,
                                     outputs=target_layer.output)

    # predict() returns one entry per sample; keep the first one only.
    return intermediate_layer_model.predict(img)[0]


def _normalize_grads(grads):
    """Scale a symbolic gradient tensor by its root-mean-square value."""
    # The small constant keeps the division defined for an all-zero gradient.
    return grads / (K.sqrt(K.mean(K.square(grads))) + 1e-5)


def _deprocess_image(x):
    """Convert a float array into a uint8 array in the range [0, 255]."""
    # Centre on 0 and rescale to a standard deviation of about 0.1.
    x = x - x.mean()
    x = x / (x.std() + 1e-5)
    # Shift to be centred on 0.5, clip to [0, 1], then stretch to [0, 255].
    x = np.clip(x * 0.1 + 0.5, 0, 1) * 255
    return np.clip(x, 0, 255).astype('uint8')


def conv_filter(model, layer_name, img, steps=40, step=1.):
    """Get the input patterns that maximise each filter of a layer.

    For every channel of the named layer, gradient ascent is run on a copy
    of ``img`` to increase the mean activation of that channel.

    Args:
           model: keras model.
           layer_name: name of layer in the model; the model's first layer
               is not searched.
           img: processed input image batch used as the starting point;
               it is copied, not modified.
           steps: number of gradient-ascent iterations per filter.
           step: step size applied to the normalised gradient.

    Returns:
           filters: array of the resulting uint8 images, one per channel,
           ordered from highest to lowest final loss value.

    Raises:
           ValueError: if ``steps`` is smaller than 1, or if no layer after
           the first one is called ``layer_name``.
    """
    if steps < 1:
        raise ValueError('steps must be at least 1')

    # this is the placeholder for the input images
    input_img = model.input

    # map layer names to layers, skipping the first layer of the model
    layer_dict = {layer.name: layer for layer in model.layers[1:]}

    try:
        layer_output = layer_dict[layer_name].output
    except KeyError as err:
        raise ValueError('Not layer named {}!'.format(layer_name)) from err

    kept_filters = []
    for i in range(layer_output.shape[-1]):
        # the quantity to maximise: mean activation of channel i
        loss = K.mean(layer_output[:, :, :, i])

        # gradient of that loss with respect to the input picture
        grads = K.gradients(loss, input_img)[0]

        # normalization trick: keep the update magnitude comparable
        # between iterations and filters
        grads = _normalize_grads(grads)

        # this function returns the loss and grads given the input picture
        iterate = K.function([input_img], [loss, grads])

        # start every filter from a fresh copy of the caller's image
        fimg = img.copy()
        for _ in range(steps):
            loss_value, grads_value = iterate([fimg])
            fimg += grads_value * step

        # decode the first sample of the optimised input into an image
        kept_filters.append((_deprocess_image(fimg[0]), loss_value))

    # a single stable sort after the loop gives the same order as
    # re-sorting after every append
    kept_filters.sort(key=lambda item: item[1], reverse=True)

    return np.array([image for image, _ in kept_filters])

def output_heatmap(model, last_conv_layer, img):
    """Get the class-activation heatmap for an image.

    The heatmap is built from the output of the named layer, with each
    channel weighted by the mean gradient of the predicted class score
    with respect to that channel.

    Args:
           model: keras model.
           last_conv_layer: name of last conv layer in the model.
           img: processed input image batch; the class is chosen from the
               prediction for its first sample.

    Returns:
           heatmap: 2-D array of non-negative values, scaled so that its
           maximum is 1. If no value is positive the heatmap is all zeros.
    """
    # predict the image class and find the index of the top-scoring class
    preds = model.predict(img)
    index = np.argmax(preds[0])
    # This is the entry in the prediction vector
    target_output = model.output[:, index]

    # look the layer up under a separate name so the parameter keeps
    # referring to the layer's name
    conv_layer = model.get_layer(last_conv_layer)

    # gradient of the target class score w.r.t. the layer's feature map
    grads = K.gradients(target_output, conv_layer.output)[0]

    # one mean gradient per feature-map channel
    pooled_grads = K.mean(grads, axis=(0, 1, 2))

    # this function returns the pooled grads and the feature map of the
    # first sample, given the input picture
    iterate = K.function([model.input], [pooled_grads, conv_layer.output[0]])
    pooled_grads_value, conv_layer_output_value = iterate([img])

    # Weight every channel by "how important this channel is" for the
    # target class; broadcasting applies one weight per last-axis channel
    # without modifying the fetched array in place.
    weighted = conv_layer_output_value * pooled_grads_value

    # The channel-wise mean of the weighted feature map, with negative
    # values clipped, is the heatmap of class activation.
    heatmap = np.maximum(np.mean(weighted, axis=-1), 0)

    # Scale to a maximum of 1; skip when everything was clipped to zero,
    # which would otherwise divide by zero and fill the map with NaN.
    peak = np.max(heatmap)
    if peak > 0:
        heatmap /= peak

    return heatmap

# Load the sample image and display it.
# Every backslash is escaped so the literal contains no invalid escape
# sequences; the resulting path string is unchanged.
img_path = '..\\testset\\ClusterKDEpicture_multiKde\\IC1311.r0.2.fits_loc.jpg'
im = np.array(Image.open(img_path))  # read the image into an array
plt.imshow(im, cmap=plt.cm.binary)
plt.show()
# Add a leading batch axis; the reshape requires exactly 60*60*3 values.
img = im.reshape(1, 60, 60, 3).astype('float32')
model_path = '..\\model\\s2_best_weights_v1.0.h5'
model = load_model(model_path)
# summary() prints the table itself; wrapping it in print() only adds "None".
model.summary()
last_layer_name = 'conv2d_5'

# Plot the prediction heatmap.
heat_map = output_heatmap(model, last_layer_name, img)
plt.rcParams['figure.figsize'] = (5, 5)
plt.imshow(heat_map, cmap=plt.cm.binary)
plt.show()
#%%
# Plot the output images of one convolutional layer.
def _show_feature_maps(feature_maps, rows=4, cols=4):
    """Draw the leading channels of a feature map on a rows x cols grid.

    Args:
           feature_maps: array indexed as [row, column, channel].
           rows: number of grid rows.
           cols: number of grid columns.
    """
    # Never index past the last channel when the layer has fewer channels
    # than the grid has cells.
    n_maps = min(rows * cols, feature_maps.shape[-1])
    for idx in range(n_maps):
        plt.subplot(rows, cols, idx + 1)
        plt.xticks([])
        plt.yticks([])
        plt.axis('off')
        plt.imshow(feature_maps[:, :, idx], cmap=plt.cm.hot)
    plt.show()


last_layer_name = 'conv2d_2'

# Feature maps for the image that is already loaded in `img`.
output = conv_output(model, last_layer_name, img)
_show_feature_maps(output)

# Repeat for a second image, displaying the image itself first.
img_path = '..\\testset\\noClusterKDEpicture_multiKde\\no_NGC_6152.csv_loc.jpg'
im = np.array(Image.open(img_path))  # read the image into an array
plt.imshow(im, cmap=plt.cm.binary)
plt.show()
# Add a leading batch axis; the reshape requires exactly 60*60*3 values.
img = im.reshape(1, 60, 60, 3).astype('float32')
output = conv_output(model, last_layer_name, img)
_show_feature_maps(output)